'''Train CIFAR10 with PyTorch.'''
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn 
import os.path as osp 
import numpy as np 
import random
# import timm
from pytorchcv.model_provider import get_model as ptcv_get_model
from dataset.oxford_pet import Databasket, CUB
from fastai.vision.all import *
from dataset import cifar

import torchvision
import torchvision.transforms as transforms

import os
import argparse

# from models import *
import torchvision.models as models
# import mobilenetv1
from utils import progress_bar


parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
parser.add_argument('--lr', default=0.01, type=float, help='learning rate')
parser.add_argument('--net', default='resnet50', type=str, help='architecture')
parser.add_argument('--bs', default=128, type=int, help='batch size')
parser.add_argument('--seed', default=0, type=int, help='random seed')
parser.add_argument('--resume', '-r', action='store_true',
                    help='resume from checkpoint')
parser.add_argument('--multigpu', '-m', action='store_true',
                    help='wrap the model in DataParallel when CUDA is available')
parser.add_argument('--gpu', '-g', default='0', type=str,
                    help='gpu id(s), exported as CUDA_VISIBLE_DEVICES')

parser.add_argument('--input_size', '-i', default=224, type=int,
                    help='network input side length (train random-crop size, test resize size)')
parser.add_argument('--crop_size', default=256, type=int,
                    help='side length training images are resized to before the random crop')
parser.add_argument('--workers', '-w', default=4, type=int,
                    help='number of DataLoader worker processes')

parser.add_argument('--src', default='cifar10', type=str, help='source set')
parser.add_argument('--resplit', action='store_true',
                    help='forwarded as `resplit` to the CIFAR / Oxford-Pets dataset loaders')
args = parser.parse_args()

# Expose only the requested GPU(s). This only takes effect if CUDA has not
# been initialised yet, so it runs before any device query below.
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

# Seed the Python, NumPy and torch RNGs from --seed.
SEED = args.seed
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
best_acc = 0      # best test accuracy seen so far (percent)
start_epoch = 0   # first epoch to run; overwritten when resuming

# Per-(dataset, architecture) directory holding the checkpoint and log file.
ckpt_src_dir = f'checkpoint/src/{args.src}/{args.net}'
# Data
print('==> Preparing data..')


def _to_rgb(img):
    """Return *img* as a 3-channel RGB PIL image.

    The Normalize step below uses 3-channel statistics, so single-channel
    inputs (e.g. torchvision's FashionMNIST yields mode-'L' images) would
    fail there. Images that are already RGB are passed through unchanged.
    This is a module-level function rather than a lambda so the transform
    stays picklable for DataLoader worker processes.
    """
    if img.mode != 'RGB':
        img = img.convert('RGB')
    return img


# Channel mean/std below are the commonly used CIFAR-10 statistics; they are
# applied to every source dataset.
transform_train = transforms.Compose([
    transforms.Lambda(_to_rgb),
    transforms.Resize([args.crop_size, args.crop_size]),
    transforms.RandomCrop([args.input_size, args.input_size]),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.Lambda(_to_rgb),
    transforms.Resize([args.input_size, args.input_size]),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# Each branch selects the source dataset: it sets NUM_CLS_SRC (the classifier
# width) plus the train/test datasets. The DataLoaders are built once below.
if args.src == 'cifar10':
    NUM_CLS_SRC = 10
    trainset = cifar.CIFAR10(
        root='./data', train=True, download=True, transform=transform_train, resplit=args.resplit)
    testset = cifar.CIFAR10(
        root='./data', train=False, download=True, transform=transform_test, resplit=args.resplit)
elif args.src == 'cifar100':
    NUM_CLS_SRC = 100
    trainset = cifar.CIFAR100(
        root='./data', train=True, download=False, transform=transform_train, resplit=args.resplit)
    testset = cifar.CIFAR100(
        root='./data', train=False, download=False, transform=transform_test, resplit=args.resplit)
elif args.src == 'imagenette':
    NUM_CLS_SRC = 10
    train_root = './data/imagenette/train'
    test_root = './data/imagenette/val'
    trainset = torchvision.datasets.ImageFolder(train_root, transform=transform_train)
    # Evaluate on the held-out val folder, not on the training images.
    testset = torchvision.datasets.ImageFolder(test_root, transform=transform_test)
elif args.src == 'oxfordpets':
    NUM_CLS_SRC = 37
    databasket = Databasket(train_transforms=transform_train, val_transforms=transform_test, resplit=args.resplit)
    trainset = databasket.train_ds
    testset = databasket.val_ds
elif args.src == 'oxfordflowers':
    NUM_CLS_SRC = 102
    trainset = torchvision.datasets.Flowers102(root='./data/', split='train', transform=transform_train, download=False)
    testset = torchvision.datasets.Flowers102(root='./data/', split='val', transform=transform_test, download=False)
elif args.src == 'CUB':
    NUM_CLS_SRC = 200
    # NOTE: --resplit is not applied to CUB.
    trainset = CUB(root='./data/CUB', is_train=True, transform=transform_train)
    testset = CUB(root='./data/CUB', is_train=False, transform=transform_test)
elif args.src == 'DTD':
    NUM_CLS_SRC = 47
    # Machine-specific absolute path; both splits must share the same root.
    dtd_root = '/data/yuhe.ding/DATA/DTD'
    trainset = torchvision.datasets.DTD(root=dtd_root, split='train', transform=transform_train, download=True)
    testset = torchvision.datasets.DTD(root=dtd_root, split='val', transform=transform_test, download=True)
elif args.src == 'food101':
    NUM_CLS_SRC = 101
    trainset = torchvision.datasets.Food101(root='./data/', split='train', transform=transform_train, download=True)
    testset = torchvision.datasets.Food101(root='./data/', split='test', transform=transform_test, download=True)
elif args.src == 'country211':
    NUM_CLS_SRC = 211
    trainset = torchvision.datasets.Country211(root='./data/', split='train', transform=transform_train, download=True)
    testset = torchvision.datasets.Country211(root='./data/', split='valid', transform=transform_test, download=True)
elif args.src == 'place365':
    NUM_CLS_SRC = 365
    trainset = torchvision.datasets.Places365(root='./data/', split='train-standard', transform=transform_train, download=True)
    testset = torchvision.datasets.Places365(root='./data/', split='val', transform=transform_test, download=True)
elif args.src == 'stanfordcars':
    NUM_CLS_SRC = 196
    trainset = torchvision.datasets.StanfordCars(root='./data/', split='train', transform=transform_train, download=True)
    testset = torchvision.datasets.StanfordCars(root='./data/', split='test', transform=transform_test, download=True)
elif args.src == 'caltech101':
    # Caltech101 has 101 object categories (196 was copied from StanfordCars).
    NUM_CLS_SRC = 101
    # NOTE(review): Caltech101 has no official split; train and test below are
    # the same images, so test accuracy is measured on training data.
    trainset = torchvision.datasets.Caltech101(root='./data/', transform=transform_train, download=True)
    testset = torchvision.datasets.Caltech101(root='./data/', transform=transform_test, download=True)
elif args.src == 'celeba':
    NUM_CLS_SRC = 40
    # NOTE(review): CelebA's default target_type='attr' gives a 40-dim binary
    # (multi-label) target per image, which this script's CrossEntropyLoss
    # over class indices is presumably not set up for; confirm target_type.
    trainset = torchvision.datasets.CelebA(root='./data/', split='train', transform=transform_train, download=True)
    testset = torchvision.datasets.CelebA(root='./data/', split='test', transform=transform_test, download=True)
elif args.src == 'fashionmnist':
    NUM_CLS_SRC = 10
    trainset = torchvision.datasets.FashionMNIST(root='./data/', train=True, transform=transform_train, download=True)
    testset = torchvision.datasets.FashionMNIST(root='./data/', train=False, transform=transform_test, download=True)
elif args.src == 'svhn':
    NUM_CLS_SRC = 10
    trainset = torchvision.datasets.SVHN(root='./data/', split='train', transform=transform_train, download=True)
    testset = torchvision.datasets.SVHN(root='./data/', split='test', transform=transform_test, download=True)
elif args.src == 'fgvcaircraft':
    NUM_CLS_SRC = 100
    trainset = torchvision.datasets.FGVCAircraft(root='./data/', split='train', transform=transform_train, download=True)
    testset = torchvision.datasets.FGVCAircraft(root='./data/', split='test', transform=transform_test, download=True)
elif args.src == 'gtsrb':
    NUM_CLS_SRC = 43
    trainset = torchvision.datasets.GTSRB(root='./data/', split='train', transform=transform_train, download=True)
    testset = torchvision.datasets.GTSRB(root='./data/', split='test', transform=transform_test, download=True)
elif args.src == 'inaturalist':
    NUM_CLS_SRC = 10000
    trainset = torchvision.datasets.INaturalist(root='./data/', version='2021_train', transform=transform_train, download=True)
    testset = torchvision.datasets.INaturalist(root='./data/', version='2021_valid', transform=transform_test, download=True)
elif args.src == 'renderedsst2':
    NUM_CLS_SRC = 2
    trainset = torchvision.datasets.RenderedSST2(root='./data/', split='train', transform=transform_train, download=True)
    testset = torchvision.datasets.RenderedSST2(root='./data/', split='test', transform=transform_test, download=True)
elif args.src == 'stl10':
    NUM_CLS_SRC = 10
    trainset = torchvision.datasets.STL10(root='./data/', split='train', transform=transform_train, download=True)
    testset = torchvision.datasets.STL10(root='./data/', split='test', transform=transform_test, download=True)
else:
    # Fail here with a clear message instead of a NameError on NUM_CLS_SRC later.
    raise ValueError('Unsupported --src: {!r}'.format(args.src))

# Every source uses the same loader settings: shuffled training, ordered testing.
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=args.bs, shuffle=True, num_workers=args.workers)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=args.bs, shuffle=False, num_workers=args.workers)


# Model
print('==> Building model..')

# --net value -> (constructor, checkpoint file name inside ckpt_src_dir).
# Constructors are wrapped in lambdas so only the selected network is built.
_ARCHITECTURES = {
    'resnet50': (lambda: models.resnet50(pretrained=True), 'resnet50-v2.pth'),
    'resnet101': (lambda: models.resnet101(), 'resnet101-v2.pth'),
    'resnet152': (lambda: models.resnet152(), 'resnet152-v2.pth'),
    'densenet169': (lambda: models.densenet169(pretrained=True), 'densenet169.pth'),
    'densenet121': (lambda: models.densenet121(pretrained=True), 'densenet121.pth'),
    'densenet201': (lambda: models.densenet201(pretrained=True), 'densenet201.pth'),
    'mobilenetv1': (lambda: ptcv_get_model("mobilenet_w1", pretrained=False),
                    'mobilenet_w1-0895-7e1d739f.pth'),
    'mobilenetv2': (lambda: models.mobilenet_v2(), 'mobilenet_v2.pth'),
    'mobilenetv3_large': (lambda: models.mobilenet_v3_large(), 'mobilenet_v3_large-8738ca79.pth'),
    'mobilenetv3_small': (lambda: models.mobilenet_v3_small(), 'mobilenet_v3_small-047dcff4.pth'),
    'swin_b': (lambda: models.swin_b(), 'swin_b.pth'),
    'swin_v2_b': (lambda: models.swin_v2_b(), 'swin_v2_b.pth'),
    'vit_b_16': (lambda: models.vit_b_16(), 'vit_b_16.pth'),
    # torchvision names this constructor wide_resnet101_2 (there is no *_v2).
    'wide_resnet101_2': (lambda: models.wide_resnet101_2(), 'wide_resnet101_2.pth'),
    'efficientnetb0': (lambda: models.efficientnet_b0(pretrained=True),
                       'efficientnet_b0_rwightman-3dd342df.pth'),
    'efficientnetb1': (lambda: models.efficientnet_b1(), 'efficientnet_b1_rwightman-533bc792.pth'),
    'efficientnetb2': (lambda: models.efficientnet_b2(pretrained=True),
                       'efficientnet_b2_rwightman-bcdf34b7.pth'),
    'efficientnetb3': (lambda: models.efficientnet_b3(pretrained=True),
                       'efficientnet_b3_rwightman-cf984f9c.pth'),
    'vgg16': (lambda: models.vgg16(), 'vgg16-397923af.pth'),
    'vgg19': (lambda: models.vgg19(), 'vgg19-dcbb9e9d.pth'),
}


def _replace_last_linear(model, num_classes):
    """Swap the last ``nn.Linear`` registered in *model* for one with *num_classes* outputs.

    The classifier attribute differs between model families (``fc`` on
    ResNets, ``classifier`` or ``classifier[-1]`` on DenseNet / VGG /
    MobileNet / EfficientNet, ``head`` on Swin, ``heads.head`` on ViT), so a
    fixed ``model.fc = Linear(2048, ...)`` only works for ResNets. This helper
    assumes the final ``nn.Linear`` in registration order is the
    classification head; confirm this when adding a new architecture.

    Args:
        model: network whose head is replaced in place.
        num_classes: output width of the new head.

    Raises:
        ValueError: if *model* contains no ``nn.Linear`` layer.
    """
    last_name = None
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            last_name = name
    if last_name is None:
        raise ValueError('model has no nn.Linear layer to replace')
    parent_name, _, child_name = last_name.rpartition('.')
    parent = model.get_submodule(parent_name) if parent_name else model
    old_head = getattr(parent, child_name)
    setattr(parent, child_name,
            nn.Linear(old_head.in_features, num_classes, bias=old_head.bias is not None))


try:
    _build_net, _ckpt_name = _ARCHITECTURES[args.net]
except KeyError:
    raise ValueError('Unsupported --net: {!r} (choose from {})'.format(
        args.net, ', '.join(sorted(_ARCHITECTURES)))) from None
net = _build_net()
# Path test() writes the best checkpoint to (and --resume reads from).
ckpt_src = ckpt_src_dir + '/' + _ckpt_name

_replace_last_linear(net, NUM_CLS_SRC)

net = net.to(device)

if args.multigpu and device == 'cuda':
    net = torch.nn.DataParallel(net)
    cudnn.benchmark = True
    cudnn.benchmark = True

if args.resume:
    # Load the checkpoint test() saves (ckpt_src). Its state_dict keys carry a
    # 'module.' prefix iff DataParallel was used, so resume with the same
    # --multigpu setting.
    print('==> Resuming from checkpoint..')
    if not os.path.isfile(ckpt_src):
        raise FileNotFoundError('Error: no checkpoint found at {}'.format(ckpt_src))
    # map_location lets a GPU-saved checkpoint be resumed on a CPU-only host.
    checkpoint = torch.load(ckpt_src, map_location=device)
    net.load_state_dict(checkpoint['net'])
    best_acc = checkpoint['acc']
    # The stored epoch was already trained and evaluated; continue after it.
    # NOTE(review): optimizer/scheduler state is not checkpointed, so the
    # LR schedule restarts from the beginning on resume.
    start_epoch = checkpoint['epoch'] + 1

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=args.lr,
                      momentum=0.9, weight_decay=5e-4)
# Anneal over the same 100 epochs the training loop runs, so the LR reaches
# its minimum instead of stopping halfway down the cosine.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)

def print_args(args):
    """Format every attribute of *args* as one ``name:value`` line.

    The output starts with a separator line of ``=`` characters. Each
    attribute follows in ``vars(args)`` order, and each line ends with a
    newline.
    """
    header = "==========================================\n"
    body = "".join("{}:{}\n".format(key, value) for key, value in vars(args).items())
    return header + body


os.makedirs(ckpt_src_dir, exist_ok=True)
# Render the configuration before attaching the file handle, so the handle
# itself is not written to the log as if it were an argument.
_args_dump = print_args(args)
# Append when resuming so the earlier run's history is kept.
args.out_file = open(osp.join(ckpt_src_dir, 'log.txt'), 'a' if args.resume else 'w')
args.out_file.write(_args_dump + '\n')
args.out_file.flush()

# Training
def train(epoch):
    """Run one training epoch over ``trainloader`` and log its metrics.

    Uses the module-level ``net``, ``optimizer``, ``criterion``, ``device``,
    ``trainloader`` and ``args.out_file``, and updates ``net``'s parameters
    in place.

    Args:
        epoch: Index of the current epoch (used for printing/logging only).
    """
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    num_batches = 0
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        num_batches = batch_idx + 1

        progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                     % (train_loss/num_batches, 100.*correct/total, correct, total))

    # max(..., 1) lets an empty loader log zeros instead of raising
    # (batch_idx would be unbound and total would be 0).
    avg_loss = train_loss / max(num_batches, 1)
    acc = 100. * correct / max(total, 1)
    args.out_file.write('\nEpoch: %d' % epoch + '\n' + 'Train | Loss: %.3f | Acc: %.3f%% (%d/%d)'
                        % (avg_loss, acc, correct, total) + '\n')
    args.out_file.flush()

def test(epoch):
    """Evaluate ``net`` on ``testloader``, log the metrics and keep the best checkpoint.

    Reads the module-level ``net``, ``testloader``, ``criterion``, ``device``,
    ``args.out_file`` and ``ckpt_src``. When this epoch's accuracy beats the
    global ``best_acc``, it saves ``{'net', 'acc', 'epoch'}`` to ``ckpt_src``
    and updates ``best_acc``.

    Args:
        epoch: Index of the current epoch; stored in the saved checkpoint.
    """
    global best_acc
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    num_batches = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            num_batches = batch_idx + 1

            progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                         % (test_loss/num_batches, 100.*correct/total, correct, total))

    # max(..., 1) lets an empty loader log zeros instead of raising
    # (batch_idx would be unbound and total would be 0).
    avg_loss = test_loss / max(num_batches, 1)
    acc = 100. * correct / max(total, 1)
    args.out_file.write('Test | Loss: %.3f | Acc: %.3f%% (%d/%d)' % (avg_loss, acc, correct, total) + '\n')
    args.out_file.flush()

    # Save checkpoint.
    if acc > best_acc:
        print('Saving..')
        state = {
            'net': net.state_dict(),
            'acc': acc,
            'epoch': epoch,
        }
        torch.save(state, ckpt_src)
        best_acc = acc


# Train for 100 epochs starting at start_epoch, evaluating after each one.
try:
    for epoch in range(start_epoch, start_epoch+100):
        train(epoch)
        test(epoch)
        scheduler.step()
finally:
    # Close the log even if training is interrupted (e.g. Ctrl-C or an error).
    args.out_file.close()
